# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Copy-paste from PaddleSeg with minor modifications.
https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.1/paddleseg/utils/utils.py
"""

import contextlib
import os
import tempfile

import paddle

from autoshape.utils import logger


@contextlib.contextmanager
def generate_tempdir(directory: str = None, **kwargs):
    '''Generate a temporary directory that is removed on context exit.'''
    # Unlike PaddleSeg, ``seg_env.TMP_HOME`` is not available in this
    # package, so fall back to the system default temporary directory
    # when ``directory`` is None.
    with tempfile.TemporaryDirectory(dir=directory, **kwargs) as _dir:
        yield _dir
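
# Illustrative usage of ``generate_tempdir`` (a sketch; the file name is
# hypothetical):
#
#     with generate_tempdir() as tmp_dir:
#         tmp_file = os.path.join(tmp_dir, 'weights.pdparams')
#         ...  # the directory and its contents are removed on exit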


def load_pretrained_model(model, pretrained_model):
    '''Load the parameters stored at ``pretrained_model`` into ``model``,
    skipping parameters that are missing or whose shapes do not match.'''
    if os.path.exists(pretrained_model):
        para_state_dict = paddle.load(pretrained_model)

        model_state_dict = model.state_dict()
        keys = model_state_dict.keys()
        num_params_loaded = 0
        for k in keys:
            if k not in para_state_dict:
                logger.warning("{} is not in pretrained model".format(k))
            elif list(para_state_dict[k].shape) != list(
                    model_state_dict[k].shape):
                logger.warning(
                    "[SKIP] Shape of pretrained params {} doesn't match. "
                    "(Pretrained: {}, Actual: {})".format(
                        k, para_state_dict[k].shape,
                        model_state_dict[k].shape))
            else:
                model_state_dict[k] = para_state_dict[k]
                num_params_loaded += 1
        model.set_dict(model_state_dict)
        logger.info("There are {}/{} variables loaded into {}.".format(
            num_params_loaded, len(model_state_dict),
            model.__class__.__name__))

    else:
        raise ValueError(
            'The pretrained model is not found: {}'.format(pretrained_model))
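
# Illustrative usage of ``load_pretrained_model`` (a sketch; ``MyNet`` and
# the checkpoint path are hypothetical):
#
#     model = MyNet()
#     load_pretrained_model(model, 'pretrained/model.pdparams')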


def resume(model, optimizer, resume_model):
    '''Restore model and optimizer state from the checkpoint directory
    ``resume_model`` and return the iteration number encoded in its name
    (e.g. ``iter_1000``). Returns None when ``resume_model`` is None.'''
    if resume_model is not None:
        logger.info('Resume model from {}'.format(resume_model))
        if os.path.exists(resume_model):
            resume_model = os.path.normpath(resume_model)
            ckpt_path = os.path.join(resume_model, 'model.pdparams')
            para_state_dict = paddle.load(ckpt_path)
            ckpt_path = os.path.join(resume_model, 'model.pdopt')
            opti_state_dict = paddle.load(ckpt_path)
            model.set_state_dict(para_state_dict)
            optimizer.set_state_dict(opti_state_dict)

            # The directory name is expected to end with the iteration
            # number, e.g. ``output/iter_1000``.
            resume_iter = int(resume_model.split('_')[-1])
            return resume_iter
        else:
            raise ValueError(
                'Directory of the model needed to resume is not found: {}'.
                format(resume_model))
    else:
        logger.info('No model needed to resume.')